# ------------------------------------------------------------------------------
# Prepare main modeling cohort
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# Pulls from <https://gitlab.com/labsysmed/zolab-projects/stress_test_medicare/-/blob/master/code/02_analysis_construction/08_sample_split.R>
# To run: bsub -q big -R "rusage[mem=30000]" bash 01_prep_modeling_data.sh {split}
# ------------------------------------------------------------------------------

# Setup ------------------------------------------------------------------------
# Fix the RNG seed so any random draws made later in the script are repeatable.
set.seed(1)

# NOTE: attach order decides which package wins when two packages export the
# same name, so keep the order of the library() calls below unchanged.
library(yaml) # read_yaml absolute filepaths
library(data.table)
library(here) # here() relative filepaths
library(testit) # assert function
library(tibble) # tribble
library(tidyverse)
library(feather)
library(lubridate) # force_tz
library(glue) # glue()
library(optparse) # make_option
library(RSQLite) # connect to xwalk database

# Project utility module loaded from lib/util.R; used below as
# u$safe_left_join, u$assign_train_folds and u$sparsify.
u <- modules::use(here::here("lib", "util.R"))

# Command Line Args ------------------------------------------------------------
# --split names the cohort split to prepare. It selects the output
# subdirectory and, when equal to "overnight", the "_overnight" cohort files.
arg_config <- list(
  make_option(
    "--split",
    type = "character",
    help = "Cohort split to prepare (\"overnight\" reads the _overnight cohorts)"
  )
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

# Fail fast: without --split the values derived below would silently become
# zero-length and the script would only break much later, on an unrelated line.
if (length(arg_list$split) != 1 || is.na(arg_list$split) ||
  !nzchar(arg_list$split)) {
  stop("Missing required argument: --split", call. = FALSE)
}
message("Split: ", arg_list$split)

# Directories ------------------------------------------------------------------
paths <- read_yaml(here::here("lib", "filepaths.yml"))
# Suffix appended to the cohort keys looked up in paths$cohort.
overnight_lab <- if (arg_list$split == "overnight") "_overnight" else ""
split <- arg_list$split
# Cohort files for this split are written under <modeling dir>/cohorts/<split>.
save_dir <- file.path(paths$modeling$dir, "cohorts", arg_list$split)

# Helper Functions -------------------------------------------------------------
# Assemble, flag and save the modeling cohort for a single split group.
#
# Reads the cohort RDS registered in `paths$cohort` for `split_group` (the key
# is suffixed with `overnight_lab`), left-joins the script-level
# `adverse_outcomes`, `ed_tests` and `ami` tables, and derives the
# `mace_or_stent` outcome and the `exclude_modeling` flag. When `split_group`
# contains "train", train folds and downsampling flags are attached as well.
# The result is written to `<save_dir>/<split_group>_cohort.rds`.
#
# Reads these script-level objects: paths, overnight_lab, save_dir,
# adverse_outcomes, ed_tests, ami, nfolds, u.
#
# Args:
#   split_group: Name of the split to prepare (the script passes "train",
#     "val" and "test").
#
# Returns:
#   NA. Called for its side effect of writing the cohort file.
combine_outcomes <- function(split_group) {
  message("Preparing ", split_group, " cohort...")

  cohort_key <- glue("{split_group}{overnight_lab}")
  base_cohort <- readRDS(paths$cohort[[cohort_key]]) %>%
    mutate(test_010_day = (cohort == "tested")) %>%
    select(
      ptid, ed_enc_id, t0, start_datetime, end_datetime,
      test_010_day, excl_flag_c_int, excl_flag_chronic, excl_flag_death
    )

  # Attach the outcome tables one at a time.
  with_outcomes <- u$safe_left_join(base_cohort, adverse_outcomes)
  with_outcomes <- u$safe_left_join(with_outcomes, ed_tests)
  with_outcomes <- u$safe_left_join(with_outcomes, ami)

  assert(
    "Row order same pre- and post-merge",
    all(base_cohort$ed_enc_id == with_outcomes$ed_enc_id)
  )

  # Encounters without a row in the interventions table had no stent or CABG.
  intervention_defaults <- list(
    stent_010_day = FALSE, cabg_010_day = FALSE,
    stent_or_cabg_010_day = FALSE
  )

  cleaned <- with_outcomes %>%
    replace_na(intervention_defaults) %>%
    mutate(
      # Tested encounters are labelled by stent/CABG, untested ones by
      # positive troponin or death.
      mace_or_stent = (test_010_day & stent_or_cabg_010_day) |
        (!test_010_day & macetrop_pos_or_death_030),
      exclude_modeling = excl_flag_c_int | excl_flag_chronic |
        excl_flag_death | (!test_010_day & ami_day_of)
    ) %>%
    setDT()

  if (grepl("train", split_group)) {
    message("Assigning train folds...")
    fold_ids <- u$assign_train_folds(cleaned, nfolds)

    message("Adding downsample flags...")
    # One row per downsampling flag; add rows here to create further flags
    # based on other outcomes.
    downsample_design <- tribble(
      ~outcome, ~target_prop, ~flag_name,
      "macetrop_pos_or_death_030", 0.30, "downsample_keep_mace",
      "test_010_day", 0.30, "downsample_keep_test"
    )

    # add_downsample_flag writes each flag column into fold_ids by reference.
    pwalk(downsample_design, add_downsample_flag, DT = fold_ids)

    fold_flags <- select(
      fold_ids,
      ptid, train_fold, downsample_keep_mace, downsample_keep_test
    )
    cleaned <- u$safe_left_join(cleaned, fold_flags) %>%
      mutate(
        # Untested rows that neither downsampling scheme keeps.
        non_downsampled = !test_010_day & !downsample_keep_mace &
          !downsample_keep_test
      )

    assert(
      "No null train_folds",
      !anyNA(cleaned[["train_fold"]])
    )
  }

  save_fp <- file.path(save_dir, glue("{split_group}_cohort.rds"))
  message("Saving ", split_group, " to ", save_fp, "...")
  write_rds(cleaned, save_fp)
  message("Saved.")

  NA
}

# Add a logical keep-flag column to `DT` by reference.
#
# A row is flagged when its `outcome` column equals 1 or its
# `rand_stratum_rank` is at or below `cutoff`. When no cutoff is supplied, one
# is searched for with get_cutoff() on the rows whose `train_fold` is not 0.
#
# Args:
#   DT: Table with `train_fold`, `rand_stratum_rank` and the `outcome` column;
#     converted to a data.table in place if it is not one already.
#   outcome: Name of the outcome column, as a string.
#   target_prop: Target outcome proportion handed to get_cutoff().
#   flag_name: Name of the flag column to create, as a string.
#   cutoff: Optional rank cutoff; NULL (the default) triggers the search.
#
# Returns:
#   NULL. The new column is written into `DT` itself.
add_downsample_flag <- function(DT, outcome, target_prop, flag_name,
                                cutoff = NULL) {
  setDT(DT)

  if (is.null(cutoff)) {
    search_rows <- DT[train_fold != 0]
    cutoff <- get_cutoff(search_rows, outcome, target_prop)
  }

  # `outcome` holds a column name, so look the column up with get().
  DT[, (flag_name) := get(outcome) == 1 | rand_stratum_rank <= cutoff]

  NULL
}

# Bisect for the rank cutoff at which the retained outcome rate hits `target`.
#
# A row counts as retained when its outcome is TRUE or its
# `rand_stratum_rank` is at or below the candidate cutoff -- the same `<=` rule
# add_downsample_flag() applies when it writes the flag, so the rate evaluated
# here is the rate of the rows that end up flagged. The candidate is the
# midpoint of [lower, upper]; the interval is halved until the outcome rate
# among retained rows lies within `tol` of `target`.
#
# Args:
#   DT: Table with the `outcome` column (logical or 0/1) and a numeric
#     `rand_stratum_rank` column.
#   outcome: Name of the outcome column, as a string.
#   target: Desired outcome rate among retained rows.
#   tol: Accepted absolute distance between observed and target rate.
#   lower, upper: Bounds of the cutoff search interval.
#   verbose: Emit progress messages? Applies to every step of the search.
#
# Returns:
#   The cutoff (numeric scalar) whose observed rate is within `tol` of target.
#
# Raises:
#   An error when the observed rate cannot be computed (no rows retained, or
#   missing values) or when the interval can no longer be narrowed, i.e. the
#   target is unreachable for this data.
get_cutoff <- function(DT, outcome, target, tol = 0.01, lower = 0, upper = 1,
                       verbose = TRUE) {
  is_positive <- DT[[outcome]]
  stratum_rank <- DT[["rand_stratum_rank"]]

  repeat {
    cutoff <- (lower + upper) / 2
    if (verbose) message("Trying: ", cutoff)

    retained <- is_positive | (stratum_rank <= cutoff)
    observed <- weighted.mean(is_positive, retained)

    if (is.na(observed)) {
      stop(
        "get_cutoff could not compute the outcome rate at cutoff ", cutoff,
        " (no rows retained, or missing values in '", outcome, "')",
        call. = FALSE
      )
    }

    if (observed < (target - tol)) {
      # If we haven't reached the target we need to drop more
      upper <- cutoff
    } else if (observed > (target + tol)) {
      # If we've overshot the target we need to drop less
      lower <- cutoff
    } else {
      break
    }

    if (verbose) {
      message("Observed mean: ", round(observed, 2), " is too far from target.")
    }

    # Once the midpoint stops moving the interval has collapsed; searching on
    # would repeat the same candidate forever.
    if ((lower + upper) / 2 == cutoff) {
      stop(
        "get_cutoff did not converge: target ", target, " +/- ", tol,
        " is unreachable (last observed mean: ", observed, ")",
        call. = FALSE
      )
    }
  }

  if (verbose) message("Found cutoff: ", cutoff, " with observed mean: ", observed)
  cutoff
}

# Load Data --------------------------------------------------------------------
message("Loading data...")

# death: pull the crosswalked demographics table, de-duplicated.
xwalk_con <- dbConnect(RSQLite::SQLite(), paths$features$db_xwalked)
dem_query <- dbSendQuery(xwalk_con, "SELECT * FROM dem_xwalked")
demos <- unique(dbFetch(dem_query))
dbClearResult(dem_query)
dbDisconnect(xwalk_con)

# Flag deaths within 30 of t0; a missing death date counts as no death.
# NOTE(review): assumes death_date - t0 comes out in days -- confirm.
death <- demos %>%
  mutate(
    days_to_death = death_date - t0,
    death_030_day = replace_na(days_to_death <= 30, FALSE)
  ) %>%
  select(ed_enc_id, death_030_day)

# Troponin positivity: a missing maximum counts as not positive.
troponin_flags <- read_csv(paths$cohort$troponin) %>%
  mutate(
    macetrop_030_pos = replace_na(
      `max_troponin-start_date_p1-start_date_p30` > 0,
      FALSE
    )
  ) %>%
  select(ed_enc_id, macetrop_030_pos)

# Combine troponin and death into the adverse-outcome flags.
adverse_outcomes <- u$safe_left_join(troponin_flags, death) %>%
  mutate(
    macetrop_pos_or_death_030 = macetrop_030_pos | death_030_day,
    macetrop_pos_and_death_030 = macetrop_030_pos & death_030_day
  )

# Stent / CABG interventions.
col_types <- cols(
  stent_date = col_datetime(),
  cabg_date = col_datetime()
)
ed_tests <- read_csv(paths$cohort$interventions, col_types = col_types) %>%
  mutate(
    stent_010_day = as.logical(stent),
    cabg_010_day = as.logical(cabg),
    stent_or_cabg_010_day = stent_010_day | cabg_010_day
  ) %>%
  select(
    ed_enc_id, stent_010_day, cabg_010_day, stent_or_cabg_010_day, test_date
  )

ami <- readRDS(paths$cohort$ami_encounters)

# Build and save one cohort file per split group. nfolds is read by
# combine_outcomes when it assigns train folds.
nfolds <- 5
split_groups <- c("train", "val", "test")
save_split <- map(split_groups, combine_outcomes)

# Split and sparsify features --------------------------------------------------
message("Sparsifying train features...")
# random patient level split
ids <- readRDS(file.path(save_dir, "train_cohort.rds"))
x <- readRDS(glue(paths$features$train))
# Rows of ids and x are paired by position below, so the counts must agree.
assert("Rows of train IDS and features match", nrow(ids) == nrow(x))

# Flag columns of the train cohort; each one defines a subsample to export.
downsamples <- c(
  "downsample_keep_mace", "downsample_keep_test", "test_010_day",
  "non_downsampled"
)

for (subsample in downsamples) {
  # Positions of the rows whose flag is set (NA flags are dropped by which()).
  flagged_rows <- which(ids[[subsample]] == 1)

  cohort_subset <- ids[flagged_rows, ]
  feature_subset <- x[flagged_rows, ]
  feature_subset <- u$sparsify(feature_subset)

  # Short label used in the output file names; unlisted names pass through.
  ds_label <- switch(subsample,
    downsample_keep_mace = "ds_mace",
    downsample_keep_test = "ds_test",
    test_010_day = "tested",
    subsample
  )

  x_fp <- file.path(
    paths$features$dir, split, glue("train_features_{ds_label}.rds")
  )
  message("Saving ", subsample, " features to ", x_fp)
  write_rds(feature_subset, x_fp)

  ids_fp <- file.path(save_dir, glue("train_cohort_{ds_label}.rds"))
  message("Saving ", subsample, " cohort to ", ids_fp)
  write_rds(cohort_subset, ids_fp)
}

message("Done.")
